import argparse
import os
import json
import ast
from prompt.omniActPrompt import OMNIACTPROMPT_FOROSATLAS, OMNIACTPROMPT_FORUITARAS, OMNIACT_FROGUIR1
import sys
from tqdm import tqdm
import re
sys.path.append("./")
from utils.logging_utils import setup_logger_to_stdout
from utils.schema.GUI_OWL.common import pil_to_base64, message_translate
from preprocess_base import BasePreProcess
logger = setup_logger_to_stdout()


def parse_args(args=None, namespace=None):
    """Parse the command-line options of the OmniAct preprocessing script.

    Args:
        args: Argument strings to parse; ``None`` means ``sys.argv[1:]``.
        namespace: Object to populate; ``None`` creates a new ``argparse.Namespace``.

    Returns:
        The namespace with ``dataset_name``, ``dataset_type``, ``dataset_path``,
        ``model_name`` and ``save_path`` attributes.
    """
    parser = argparse.ArgumentParser(description='Origin Dataset To Json')
    parser.add_argument('--dataset_name', type=str, default="OmniAct",
                        help='dataset name')
    parser.add_argument('--dataset_type', type=str, default='all', help='dataset type')
    parser.add_argument('--dataset_path', type=str, default="/data3/cpz/datasets/omniact/",
                        help='dataset path')
    parser.add_argument('--model_name', type=str, default="OS_ATLAS",
                        help='model name')
    parser.add_argument('--save_path', type=str, default="",
                        help='save path')
    # Forward the caller's arguments; previously they were silently ignored.
    return parser.parse_args(args, namespace)

class OmniActPreProcess(BasePreProcess):
    """Convert the OmniAct test split into per-model evaluation JSON files.

    Each public converter (``OS_ATLAS``, ``UI_TARS``, ``GUI_R1``, ...) reads
    ``test.json``, builds one record per item from the template returned by the
    matching ``BasePreProcess`` method, translates the item's pyautogui
    ground-truth action into that model's output format, and writes separate
    web / desktop files named ``<dataset_type>_{web,desktop}_<model_name lower>.json``
    under ``save_path``. Items whose screenshot cannot be read are skipped, and
    (except for ``GUI_R1``) so are items whose action has no equivalent in the
    target format.
    """

    # Lower-cased pyautogui callees, grouped by how the converters treat them.
    _CLICK_ACTIONS = ('pyautogui.click', 'pyautogui.rightclick', 'pyautogui.doubleclick')
    _MOVE_ACTIONS = ('pyautogui.moveto',)
    _POINTER_ACTIONS = _CLICK_ACTIONS + _MOVE_ACTIONS
    _PRESS_ACTIONS = ('pyautogui.press',)
    # An int or decimal literal such as ``12``, ``-3``, ``645.5`` or ``.5``.
    _NUMBER_RE = re.compile(r"-?(?:\d+(?:\.\d*)?|\.\d+)")

    def __init__(self, dataset_type, dataset_path, dataset_name, save_path, model_name):
        super().__init__(dataset_path, dataset_name, save_path, model_name)
        self.dataset_type = dataset_type
        self.dataset_path = dataset_path
        self.dataset_name = dataset_name
        self.model_name = model_name
        self.test_path = os.path.join(self.dataset_path, 'test.json')

    # ------------------------------------------------------------------
    # Per-model converters
    # ------------------------------------------------------------------

    def OS_ATLAS(self):
        """Write OS-Atlas records: ``CLICK``/``MOVETO`` points scaled to 0-1000, or ``PRESS_SPACE``."""
        sample = super().OS_ATLAS()

        def actionMapping(action, image_size):
            name = self._action_name(action)
            if name in self._POINTER_ACTIONS:
                point = self._action_point(action)
                if point is None:
                    return None
                x, y = self._scale_to_1000(*point, image_size)
                # Matching on lower-cased names fixes moveTo, which the old
                # 'pyautogui.moveTo' comparison against a lower-cased string never matched.
                verb = "CLICK" if name in self._CLICK_ACTIONS else "MOVETO"
                return f"{verb} <point>[[{int(x)}, {int(y)}]]</point>"
            if name in self._PRESS_ACTIONS:
                return "PRESS_SPACE"
            return None

        def fill_record(record, action):
            translated = actionMapping(action, record['image_size'][0])
            if translated is None:
                return False
            record['label'] = "action:\n" + translated
            record['messages'][1]['content'] = translated
            record['messages'][0]['content'] = OMNIACTPROMPT_FOROSATLAS.replace("{finalGoal}", record['goal'])
            return True

        self._convert_items(sample, fill_record)
        logger.info("Finished")

    def UI_TARS(self):
        """Write UI-TARS records: ``click``/``moveto`` with ``start_box`` or ``press_space()``."""
        sample = super().UI_TARS()

        def actionMapping(action, image_size):
            name = self._action_name(action)
            if name in self._PRESS_ACTIONS:
                return "press_space()"
            if name not in self._POINTER_ACTIONS:
                return None
            point = self._action_point(action)
            if point is None:
                return None
            x, y = point
            if "1.5" not in self.model_name:
                # UI_TARS_1.5 keeps raw pixel coordinates; other variants are scaled to 0-1000.
                x, y = self._scale_to_1000(x, y, image_size)
            verb = "click" if name in self._CLICK_ACTIONS else "moveto"
            return f"{verb}(start_box='({x},{y})')"

        def fill_record(record, action):
            translated = actionMapping(action, record['image_size'][0])
            if translated is None:
                return False
            record['messages'][1]['content'][0]['text'] = OMNIACTPROMPT_FORUITARAS.replace("{instruction}", record['goal'])
            record['messages'].append({
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": record['images'][0]
                    }
                ]
            })
            record['label'] = f"Thought: {record['goal']}\nAction: {translated}"
            return True

        self._convert_items(sample, fill_record)
        logger.info("Finished")

    def GUI_R1(self):
        """Write GUI-R1 records; non-pointer actions are labelled ``press_pgdn`` at (-100, -100)."""
        sample = super().GUI_R1()

        def actionMapping(action):
            name = self._action_name(action)
            if name is None:
                # No pyautogui line in the task file: nothing to label.
                return None
            if name in self._POINTER_ACTIONS:
                point = self._action_point(action)
                if point is None:
                    # Pointer action without usable coordinates (previously a crash).
                    return None
                # Lower-cased matching also covers pyautogui.rightClick, which the
                # old case-sensitive list sent to the press_pgdn fallback.
                action_name = 'click' if name in self._CLICK_ACTIONS else 'moveto'
                point = [int(point[0]), int(point[1])]
            else:
                action_name = "press_pgdn"
                point = [-100, -100]
            formatted_action = [{
                'action': action_name,
                'point': point,
                'input_text': 'no input text'
            }]
            return str(formatted_action)

        def fill_record(record, action):
            translated = actionMapping(action)
            if translated is None:
                return False
            record['label'] = "<think></think><answer>" + translated + "</answer>"
            record['messages'][0]['content'][0]['image'] = record['images'][0]
            record['messages'][0]['content'][1]['text'] = '<image>\n' + OMNIACT_FROGUIR1.replace("{text}", record['goal'])
            return True

        self._convert_items(sample, fill_record)
        logger.info("Finished")

    def Agent_CPM(self):
        """Write AgentCPM records: raw-pixel ``POINT`` dicts or ``PRESS: SPACE``, plus the system prompt."""
        sample = super().Agent_CPM()

        def actionMapping(action):
            name = self._action_name(action)
            if name in self._POINTER_ACTIONS:
                point = self._action_point(action)
                if point is None:
                    return None
                x, y = point
                return str({"thought": "", "POINT": [x, y]})
            if name in self._PRESS_ACTIONS:
                return str({"thought": "", "PRESS": "SPACE"})
            return None

        from prompt.omniActPrompt import AGENT_CPM_SYSTEM_PROMPT
        with open('/Agent_ScanKit/utils/schema/agentCPMSchema.json', encoding="utf-8") as schema_file:
            action_schema = json.load(schema_file)
        items = list(action_schema.items())
        # Inject a "required": ["thought"] entry as the schema's 4th key.
        items.insert(3, ("required", ["thought"]))
        system_prompt = AGENT_CPM_SYSTEM_PROMPT.replace("ACTION_SCHEMA", str(dict(items)))

        def fill_record(record, action):
            translated = actionMapping(action)
            if translated is None:
                # Unmapped action: skip, like the other converters, instead of emitting label=None.
                return False
            record['label'] = translated
            record['messages'][0]['content'][0] = record['messages'][0]['content'][0].replace("text_prompt", record['goal'])
            record['system_prompt'] = system_prompt
            return True

        self._convert_items(sample, fill_record)
        logger.info("Finished")

    def Aguvis(self):
        """Write Aguvis records: ``pyautogui.click``/``moveto`` calls or ``pyautogui.space()``."""
        sample = super().Aguvis()

        def actionMapping(action):
            name = self._action_name(action)
            if name in self._POINTER_ACTIONS:
                point = self._action_point(action)
                if point is None:
                    return None
                x, y = point
                # NOTE(review): raw pixels are divided by 1000, not by the image size,
                # so values exceed 1 on screenshots wider/taller than 1000 px —
                # confirm this is what the Aguvis evaluator expects.
                if name in self._CLICK_ACTIONS:
                    return f"assistantos\npyautogui.click(x={x/1000}, y={y/1000})"
                return f"assistantos\npyautogui.moveto(x={x/1000}, y={y/1000})"
            if name in self._PRESS_ACTIONS:
                return "assistantos\npyautogui.space()"
            return None

        from utils.schema.aguvisConstants import user_instruction

        def fill_record(record, action):
            translated = actionMapping(action)
            if translated is None:
                return False
            record['label'] = translated
            record['messages']['content'][1]['text'] = user_instruction.format(overall_goal=record['goal'], previous_actions="", low_level_instruction="")
            record['is_low_level_instruction'] = False
            record['mode'] = 'force-plan'
            return True

        self._convert_items(sample, fill_record)
        logger.info("Finished")

    def OS_Genesis(self):
        """Write OS-Genesis records: ``click``/``moveto`` JSON actions in raw pixels, or ``navigate_space``."""
        sample = super().OS_Genesis()

        def actionMapping(action):
            name = self._action_name(action)
            if name in self._POINTER_ACTIONS:
                point = self._action_point(action)
                if point is None:
                    return None
                x, y = point
                if name in self._CLICK_ACTIONS:
                    return f'Low-level thought: {""} action: {{"action_type": "click", "x": {x}, "y": {y}}}'
                return f'Low-level thought: {""} action: {{"action_type": "moveto", "x": {x}, "y": {y}}}'
            if name in self._PRESS_ACTIONS:
                return f'Low-level thought: {""} action: {{"action_type": "navigate_space"}}'
            return None

        from prompt.omniActPrompt import OS_GENESIS_PROMPT

        def fill_record(record, action):
            translated = actionMapping(action)
            if translated is None:
                return False
            record['label'] = translated
            record['question'] = OS_GENESIS_PROMPT.format(instruction=record['goal'], history="", a11y_tree="")
            return True

        self._convert_items(sample, fill_record)
        logger.info("Finished")

    def GUI_Odyssey(self):
        """Write GUI-Odyssey records (every pointer action becomes ``CLICK``) plus empty history indexes."""
        sample = super().GUI_Odyssey()

        def actionMapping(action, image_size):
            name = self._action_name(action)
            if name in self._POINTER_ACTIONS:
                point = self._action_point(action)
                if point is None:
                    return None
                x, y = self._scale_to_1000(*point, image_size)
                return f"CLICK: ({int(x)}, {int(y)})"
            if name in self._PRESS_ACTIONS:
                return "PRESS_SPACE"
            return None

        def fill_record(record, action):
            translated = actionMapping(action, record['image_size'][0])
            if translated is None:
                return False
            record['label'] = translated
            question = record['question'].format(instruction=record['goal'], image_path=record['images'][0])
            question += '\nPrevious screenshots: None'
            question += '\nPrevious Actions: None'
            question += '\nProvide the command-style action directly.'
            record['question'] = question
            return True

        data_web, data_desktop = self._convert_items(sample, fill_record)
        # Every kept screenshot is indexed with an empty history list.
        his_index_web = {record['images'][0]: [] for record in data_web}
        his_index_desktop = {record['images'][0]: [] for record in data_desktop}
        self.saveJson(his_index_web, os.path.join("/Agent_ScanKit/utils/utils_odyssey", "his_index_web.json"))
        self.saveJson(his_index_desktop, os.path.join("/Agent_ScanKit/utils/utils_odyssey", "his_index_desktop.json"))
        logger.info("Finished")

    def GUI_OWL(self):
        """Write GUI-OWL records: ``mobile_use`` tool calls with 0-1000 coordinates, in qwen message format."""
        build_system_messages, getResizedImage, build_user_messages, sample = super().GUI_OWL()

        def actionMapping(action, image_size):
            name = self._action_name(action)
            if name in self._POINTER_ACTIONS:
                point = self._action_point(action)
                if point is None:
                    return None
                x, y = self._scale_to_1000(*point, image_size)
                # moveTo is labelled as a click, as the original (never-matching) moveTo branch intended.
                return f"""<thinking>\n""\n</thinking>\n<tool_call>\n{{"name": "mobile_use", "arguments": {{"action": "click", "coordinate": [{int(x)}, {int(y)}]}}}}\n</tool_call>\n<conclusion>\n""\n</conclusion>"""
            if name in self._PRESS_ACTIONS:
                return f"""<thinking>\n""\n</thinking>\n<tool_call>\n{{"name": "mobile_use", "arguments": {{"action": "system_button", "button": "Space"}}}}\n</tool_call>\n<conclusion>\n""\n</conclusion>"""
            return None

        def fill_record(record, action):
            translated = actionMapping(action, record['image_size'][0])
            if translated is None:
                return False
            record['label'] = translated
            dummy_image = getResizedImage(record['images'][0])
            system_messages = build_system_messages(dummy_image.height, dummy_image.width)
            user_messages = build_user_messages(record['goal'], enable_think=True, history=[])
            user_messages['content'].append({"image": record['images'][0]})
            record['messages'] = message_translate([system_messages, user_messages], to_format='qwen')
            return True

        self._convert_items(sample, fill_record)
        logger.info("Finished")

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _convert_items(self, sample, fill_record):
        """Build one record per test item, keep those ``fill_record`` accepts, and save both splits.

        Args:
            sample: Template record; it is deep-copied for every item.
            fill_record: ``fill_record(record, action) -> bool``. It receives a
                record already holding ``goal``, ``images`` and ``image_size``,
                plus the raw pyautogui action line (or None). It adds the
                model-specific fields and returns False to drop the item.

        Returns:
            ``(data_web, data_desktop)``, the lists that were saved.
        """
        data_web = []
        data_desktop = []
        for data_item in tqdm(self._get_data().values()):
            record = self._base_record(sample, data_item)
            if record is None:
                continue
            action = self._get_action(data_item['task'])
            if not fill_record(record, action):
                continue
            # Split on the dataset-relative image path, so a "web" substring in
            # dataset_path itself cannot send every item to the web split.
            if 'web' in data_item['image']:
                data_web.append(record)
            else:
                data_desktop.append(record)
        self._save_splits(data_web, data_desktop)
        return data_web, data_desktop

    def _base_record(self, sample, data_item):
        """Return a copy of ``sample`` with goal, image path and size filled in, or None if the image is unreadable."""
        from copy import deepcopy
        record = deepcopy(sample)
        record['goal'] = self._get_task(data_item['task'])
        image_path = data_item['image']
        if 'web' in image_path:
            # test.json may spell web screenshots ``screen_N``; they are looked up
            # as ``screenN`` (presumably the on-disk name — verify).
            image_path = re.sub(r"screen_(\d+)", r"screen\1", image_path)
        record['images'] = [os.path.join(self.dataset_path, image_path)]
        try:
            record['image_size'] = [self._get_image_size(record['images'][0])]
        except Exception as e:
            # Unreadable screenshot: log which one and let the caller skip it.
            logger.error(f"Skipping {record['images'][0]}: {e}")
            return None
        return record

    def _save_splits(self, data_web, data_desktop):
        """Save the web and desktop lists as ``<dataset_type>_{web,desktop}_<model_name lower>.json``."""
        # os.makedirs("") raises, and "" is the default --save_path, so only create a real directory.
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)
        model = self.model_name.lower()
        self.saveJson(data_web, os.path.join(self.save_path, self.dataset_type + "_web_" + model + '.json'))
        self.saveJson(data_desktop, os.path.join(self.save_path, self.dataset_type + "_desktop_" + model + '.json'))

    @staticmethod
    def _action_name(action):
        """Return the lower-cased callee of an action line (e.g. ``'pyautogui.click'``), or None for no action."""
        if not action:
            return None
        return action.split("(")[0].lower()

    def _action_point(self, action):
        """Return the first two numeric call arguments of ``action`` as floats, or None if there are fewer than two."""
        if not action:
            return None
        match = re.search(r"\((.*)\)", action)
        if match is None:
            return None
        # Picking numbers out of the argument text also handles ``x=.., y=..``
        # keywords and extra arguments such as ``button='left'``.
        numbers = self._NUMBER_RE.findall(match.group(1))
        if len(numbers) < 2:
            return None
        return float(numbers[0]), float(numbers[1])

    @staticmethod
    def _scale_to_1000(x, y, image_size):
        """Rescale pixel coordinates to 0-1000 relative to ``image_size`` = ``[width, height]``."""
        return x / image_size[0] * 1000, y / image_size[1] * 1000

    def _get_data(self):
        """Load ``test.json``: a mapping whose values carry dataset-relative ``task`` and ``image`` paths."""
        with open(self.test_path, 'rb') as file:
            return json.load(file)

    def _get_task(self, path):
        """Return the goal text: the first line of the task file, after ``"Task: "`` if present.

        Raises:
            ValueError: if the task file is empty.
        """
        with open(os.path.join(self.dataset_path, path), 'r', encoding='utf-8') as f:
            first_line = f.readline()
        if not first_line:
            raise ValueError(f"Empty task file: {path}")
        return first_line.strip().split("Task: ")[-1]

    def _get_action(self, path):
        """Return the first stripped task-file line mentioning ``pyautogui`` (case-insensitive), or None."""
        with open(os.path.join(self.dataset_path, path), 'r', encoding='utf-8') as f:
            for line in f:
                if 'pyautogui' in line.lower():
                    return line.strip()
        return None

    def _extract_coordinates(self, action):
        """Return the first two numeric call arguments of ``action`` as floats, or ``(-1, -1)``."""
        point = self._action_point(action)
        return (-1, -1) if point is None else point

    def _get_image_size(self, image_path):
        """Return ``[width, height]`` of the image at ``image_path``.

        ``convert('RGB')`` forces the pixel data to be decoded, so unreadable
        files raise here (``_base_record`` catches that and skips the item).
        The file handle is closed by the ``with`` block.
        """
        from PIL import Image
        with Image.open(image_path) as img:
            rgb = img.convert('RGB')
        return [rgb.size[0], rgb.size[1]]
            


if __name__ == '__main__':
    args = parse_args()
    logger.info(args)
    process = OmniActPreProcess(
            args.dataset_type, args.dataset_path, args.dataset_name, args.save_path, args.model_name)
    # Each supported --model_name maps to the converter method that handles it.
    converters = {
        "OS_ATLAS": process.OS_ATLAS,
        "UI_TARS": process.UI_TARS,
        "UI_TARS_1.5": process.UI_TARS,
        "GUI_R1": process.GUI_R1,
        "Agent_CPM": process.Agent_CPM,
        "OS_Genesis": process.OS_Genesis,
        "Aguvis": process.Aguvis,
        "GUI_Odyssey": process.GUI_Odyssey,
        "GUI_OWL": process.GUI_OWL,
    }
    converter = converters.get(args.model_name)
    if converter is None:
        # Fail loudly: an unknown model used to be logged at info level and exit 0.
        logger.error(f"error processing: unsupported model_name {args.model_name!r}; "
                     f"expected one of {', '.join(converters)}")
        sys.exit(1)
    converter()
    
        
        